# import openai
import os
import re
import random
import httpx

import torch
import numpy as np
import time
from rich import print as rprint
# from math import prod

import cohere
from typing import Union
# from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv
from transformers import StoppingCriteria, StoppingCriteriaList
from utils_causal import load_causal_graph


# from .utils import convert_messages_to_prompt, retry_with_exponential_backoff
from .utils import convert_messages_to_prompt, remove_redundant_text, generate_action_index, set_seed

# Import-time side effect: fix the seed through the project helper `set_seed`
# (presumably seeds the RNG libraries used here -- confirm in .utils).
set_seed(1)

# Import-time side effect: read environment variables from the secrets file.
# The path is relative, so it depends on the process working directory.
load_dotenv("./proagent/secrets.env")

# Key passed to cohere.ClientV2 in Module.__init__; None when API_KEY is unset.
cohere_api_key = os.getenv("API_KEY")

# Token limit per OpenAI model name, looked up by Module.restrict_dialogue.
# Only the models listed here are supported by that lookup.
# Refer to https://platform.openai.com/docs/models/overview
TOKEN_LIMIT_TABLE = {
    "text-davinci-003": 4080,
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-0301": 4096,
    "gpt-3.5-turbo-16k": 16384,
    "gpt-4": 8192,
    "gpt-4-0314": 8192,
    "gpt-4-32k": 32768,
    "gpt-4-32k-0314": 32768,
}

# State/action vocabulary for the single-pot ("cramped_room") layout.
# The trailing number on each line is the entry's list index: 0-9 are state
# flags, 10-16 are action names (Module.get_action_based_on_causal_graph maps
# an action index a to OvercookedState[a + 10]).
OvercookedState = ['empty_hand', # empty hand 0
                'hold_onion', # holding onion 1
                'hold_dish', # holding empty dish 2
                'dish_with_soup', # holding dish with soup 3
                'pot_0', # pot with 0 onions 4
                'pot_1', # pot with 1 onion 5
                'pot_2', # pot with 2 onions 6
                'pot_3', # pot with 3 onions 7
                'pot_finished', # pot with soup cooked 8
                'goal_delivered', # delivered_goal 9
                'pickup(onion)', # 10
                'put_onion_in_pot()', # 11
                'pickup(dish)', # 12
                'fill_dish_with_soup()', # 13
                'deliver_soup()', # 14
                'place_onion_on_counter()', # 15
                'place_dish_on_counter()',] # 16


# State/action vocabulary used for every layout other than "cramped_room".
# The trailing number on each line is the entry's list index: 0-14 are state
# flags, 15-21 are action names (Module.get_action_based_on_causal_graph maps
# an action index a to OvercookedState_2_pot[a + 15]).
# The pot_1_* entries mirror pot_* (presumably for the second pot -- confirm).
OvercookedState_2_pot = ['empty_hand', # empty hand 0
                'hold_onion', # holding onion 1
                'hold_dish', # holding empty dish 2
                'dish_with_soup', # holding dish with soup 3
                'pot_0', # pot with 0 onions 4
                'pot_1', # pot with 1 onion 5
                'pot_2', # pot with 2 onions 6
                'pot_3', # pot with 3 onions 7
                'pot_finished', # pot with soup cooked 8
                'pot_1_0', # pot with 0 onions 9
                'pot_1_1', # pot with 1 onion 10
                'pot_1_2', # pot with 2 onions 11
                'pot_1_3', # pot with 3 onions 12
                'pot_1_finished', # pot with soup cooked 13
                'goal_delivered', # delivered_goal 14
                'pickup(onion)', # 15
                'put_onion_in_pot()', # 16
                'pickup(dish)', # 17
                'fill_dish_with_soup()', # 18
                'deliver_soup()', # 19
                'place_onion_on_counter()', # 20
                'place_dish_on_counter()',] # 21


# Define a custom stopping criterion
class StopOnWordCriteria(StoppingCriteria):
    """
    Stop generation as soon as the stop token shows up in newly generated tokens.

    args:
    tokenizer: tokenizer used to map `stopword` to a token id.
    stopword: token string passed to `tokenizer.convert_tokens_to_ids`.
    input_length: number of prompt tokens; only tokens after this offset are inspected.
    """
    def __init__(self, tokenizer, stopword, input_length):
        self.tokenizer = tokenizer
        self.input_length = input_length
        self.stopword_id = tokenizer.convert_tokens_to_ids(stopword)

    def __call__(self, input_ids, scores, **kwargs):
        # A hit in any sequence of the batch stops generation for the whole batch.
        prompt_len = self.input_length
        return any(self.stopword_id in row[prompt_len:] for row in input_ids)



class Module(object):
    """
    This module is responsible for communicating with GPTs.
    """
    def __init__(self, 
                 role_messages, 
                 model="gpt-3.5-turbo-0301",
                 retrival_method="recent_k",
                 K=3, 
                 causal_graph=None):
        """
        Set up the conversation state and, when the model name requires it, the backend.

        args:
        role_messages: message dicts forming the fixed instruction head
            (stored by reference, not copied).
        model: model identifier; the backend is chosen by substring match on it.
        retrival_method: strategy used by `get_cache` to select dialog history.
        K: number of most recent history messages kept by the "recent_k" strategy.
        causal_graph: table indexed as causal_graph[action_index][state_index]
            by the probability helpers; stored as given.
        """
        # Conversation state
        self.instruction_head_list = role_messages
        self.dialog_history_list = []
        self.current_user_message = None
        self.cache_list = None

        # Configuration
        self.model = model
        self.retrival_method = retrival_method
        self.K = K
        self.causal_graph = causal_graph
        self.chat_model = "gpt" in self.model

        # Backend: Cohere client for "command" models, local weights for HF model ids.
        uses_cohere = "command" in self.model
        uses_local_hf = any(tag in self.model for tag in ("meta-llama", "google", "Qwen"))
        if uses_cohere:
            self.co = cohere.ClientV2(cohere_api_key)
        elif uses_local_hf:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                self.model,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )

    def add_msgs_to_instruction_head(self, messages: Union[list, dict]):
        """
        Append one message (dict) or several messages (list) to the instruction head.

        Any other argument type is ignored.
        """
        if isinstance(messages, list):
            extra = messages
        elif isinstance(messages, dict):
            extra = [messages]
        else:
            return
        self.instruction_head_list += extra

    def add_msg_to_dialog_history(self, message: dict):
        """Append one message to the end of the stored dialog history."""
        history = self.dialog_history_list
        history.append(message)
    
    def get_cache(self)->list:
        """
        Select the dialog-history messages to include in the next query.

        returns:
        For the "recent_k" method, the last K history messages (an empty list
        when K is not positive). For any other retrieval method an empty list
        is returned, so the result can always be concatenated with the other
        message lists as the `-> list` annotation promises.
        """
        if self.retrival_method != "recent_k":
            # No other retrieval strategy is implemented; fall back to "no history".
            return []
        if self.K > 0:
            return self.dialog_history_list[-self.K:]
        return []
           
    @property
    def query_messages(self)->list:
        """
        Messages for the next query: instruction head, cached history, current user message.

        A new list is built on every access, but the message objects inside it
        are the stored ones (not copies).
        """
        head = self.instruction_head_list
        recent = self.cache_list
        return head + recent + [self.current_user_message]
    
    # @retry_with_exponential_backoff
    def query(self, key=None, stop=None, temperature=0.0, debug_mode = 'Y', trace = True, count=0):
        """
        Send the current conversation to the configured backend and return the reply text.

        args:
        key, debug_mode, count: accepted for call compatibility; unused in this method.
        stop: stop token for local Hugging Face generation; not forwarded to Cohere.
        temperature: forwarded to the Cohere chat call only (local generation is greedy).
        trace: if True, the dialog history is left out of the prompt; if False, the
            cached history is included and a replanning hint is appended to the
            current user message (the stored message dict is edited in place).

        returns:
        The generated text.

        raises:
        Exception: when the model name matches no supported backend, or when
            every Cohere attempt timed out.
        """
        # For traced calls K is forced to 0 so that get_cache() yields no history.
        # try/finally guarantees K is restored even if building the messages fails.
        saved_k = self.K
        try:
            # Equality (not truthiness) is kept so non-bool values behave as before.
            if trace == True:
                self.K = 0
            self.cache_list = self.get_cache()
            messages = self.query_messages
            if trace == False:
                messages[-1]['content'] += " Based on the failure explanation and scene description, analyze and plan again."
        finally:
            self.K = saved_k

        if "meta-llama" in self.model or "google" in self.model or "Qwen" in self.model:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

            # Only messages[1] is used as the prompt; a leading system message is
            # folded into it. The prompt is built as a local string so the stored
            # message dicts are not modified (otherwise the system text would be
            # prepended again on every call that reuses the same message).
            has_system = messages[0]['role'] == "system"
            prompt = messages[1]['content']
            if has_system:
                system_text = messages[0]['content']
                if "meta-llama" in self.model or "google" in self.model:
                    prompt = "<system> " + system_text + "</system>\n" + prompt
                else:
                    prompt = system_text + prompt

            input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.llm_model.device)
            input_length = input_ids["input_ids"].shape[1]

            # Greedy decoding; the stopping criterion is added only when a stop token is given.
            generate_kwargs = dict(
                max_new_tokens=256,
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True,
            )
            if stop:
                generate_kwargs["stopping_criteria"] = StoppingCriteriaList(
                    [StopOnWordCriteria(self.tokenizer, stop, input_length=input_length)]
                )
            outputs = self.llm_model.generate(**input_ids, **generate_kwargs)

            # Drop the prompt tokens and decode only the newly generated part.
            generated_tokens = outputs.sequences[:, input_length:]
            return self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

        if 'command' in self.model:
            max_attempts = 5
            for attempt in range(max_attempts):
                try:
                    response = self.co.chat(
                        model=self.model,
                        messages=messages,
                        max_tokens=256,
                        temperature=temperature
                    )
                    break  # If successful, exit the loop
                except httpx.ReadTimeout:
                    print(f"Timeout occurred. Retrying {attempt + 1}/{max_attempts}...")
                    time.sleep(5)
            else:
                raise Exception("API request failed after multiple retries.")
            return self.parse_response(response)

        raise Exception(f"Model {self.model} not supported.")


    
    def get_new_prob(self, action_index, index, llm_prob, alpha=0.3):
        """
        Add the causal-graph weights of all active state entries to an LLM probability.

        args:
        action_index: row of `self.causal_graph` to read.
        index: sequence of flags; every position holding 1 selects a column.
        llm_prob: probability assigned by the LLM.
        alpha: unused; kept for call compatibility.

        returns:
        llm_prob plus the selected weights, or llm_prob unchanged when the
        lookup fails (e.g. no causal graph or an out-of-range index).
        """
        try:
            active = [pos for pos, flag in enumerate(index) if flag == 1]
            graph_weights = [self.causal_graph[action_index][pos] for pos in active]
            return llm_prob + sum(graph_weights)
        except Exception:
            # Best effort: keep the LLM probability. `Exception` (not a bare except)
            # so KeyboardInterrupt / SystemExit still propagate.
            print("Fail")
            return llm_prob
        
    def get_new_prob_w_op_index(self, action_index, index, op_index, llm_prob, tuning, gamma=0.5, layout=None):
        """
        Combine an LLM probability with causal-graph weights of both agents' active states.

        args:
        action_index: row of `self.causal_graph` to read.
        index: flags for the current agent; positions holding 1 select columns directly.
        op_index: flags for the other agent; positions holding 1 select columns
            shifted by 17 for "cramped_room" and by 22 for any other layout
            (the size of the current agent's state/action block in the matrix).
        llm_prob: probability assigned by the LLM.
        tuning: falsy -> llm_prob + sum(weights);
            truthy -> gamma * llm_prob + (1 - gamma) * sum(weights).
        gamma: blend factor used only when `tuning` is truthy.
        layout: layout name; only "cramped_room" is treated specially.

        returns:
        The combined value, or llm_prob unchanged if anything in the computation fails.
        """
        op_offset = 17 if layout == "cramped_room" else 22
        try:
            own_active = [pos for pos, flag in enumerate(index) if flag == 1]
            partner_active = [pos + op_offset for pos, flag in enumerate(op_index) if flag == 1]
            graph_terms = [self.causal_graph[action_index][col] for col in own_active + partner_active]
            graph_total = sum(graph_terms)
            if tuning:
                return gamma * llm_prob + (1-gamma) * graph_total
            return llm_prob + graph_total
        except Exception:
            # Best effort: fall back to the unmodified LLM probability.
            return llm_prob

    
    def get_action_based_on_causal_graph(self, index, op_index, layout, alpha=0.3):
        """
        Pick the action with the largest causal-graph weight over all active states.

        args:
        index: flags for the current agent; positions holding 1 select columns directly.
        op_index: flags for the other agent; selected positions are shifted by 17
            for "cramped_room" and by 22 for any other layout.
        layout: layout name; decides the shift and which vocabulary names the action.
        alpha: unused; kept for call compatibility.

        returns:
        (action_name, weight) for the best of the 7 actions, or (None, None) when
        no state is active or the lookup fails.
        """
        try:
            single_pot = layout == "cramped_room"
            shift = 17 if single_pot else 22

            columns = [pos for pos, flag in enumerate(index) if flag == 1]
            columns.extend(pos + shift for pos, flag in enumerate(op_index) if flag == 1)

            # One (weight, action) candidate per active column and action,
            # columns in the outer loop so ties resolve in the same order.
            candidates = []
            for col in columns:
                for action in range(0, 7):
                    candidates.append((self.causal_graph[action][col], action))

            if len(candidates) == 0:
                return None, None

            max_prob, max_action_index = max(candidates, key=lambda pair: pair[0])
            print(max_prob, max_action_index)

            # Action names start at position 10 (single pot) or 15 (other layouts).
            if single_pot:
                return OvercookedState[max_action_index + 10], max_prob
            return OvercookedState_2_pot[max_action_index + 15], max_prob

        except Exception as e:
            print("Error:", e)  # Print actual error for debugging
            return None, None  # Return None in case of failure
        
    # @retry_with_exponential_backoff
    def query_w_cg(self, key=None, stop=None, temperature=0.0, debug_mode = 'Y', trace = True, agent_index = None, index=None,
                    op_index=None, failure_handled=False, tuning=False, gamma=0.5, layout=None):
        """
        Querying function with causal graph.

        For local Hugging Face models, 20 completions are sampled, each
        completion's sequence probability is adjusted with the causal graph
        (`get_new_prob_w_op_index`), and one response is drawn at random with
        the adjusted values as weights. Other backends are queried once and
        their reply is returned through `parse_response`.

        args:
        key, debug_mode: accepted for call compatibility; unused in this method.
        stop: stop token (local generation) / stop sequence (OpenAI branches).
        temperature: forwarded to the Cohere and OpenAI calls; local sampling
            always uses temperature 1.
        trace: if True, the dialog history is left out of the prompt; if False, the
            cached history is included and a replanning hint is appended to the
            current user message (the stored message dict is edited in place).
        agent_index: player number used in the "Plan for Player N" prefix.
        index, op_index: state flags of the current / other agent, passed to the
            causal-graph helpers.
        failure_handled: when post-processing of the samples fails, return the
            action suggested by the causal graph instead of an empty string.
        tuning, gamma: forwarded to `get_new_prob_w_op_index`.
        layout: layout name, forwarded to the helpers.

        returns:
        The chosen response text; "" when the local-model post-processing fails
        and no fallback applies.

        raises:
        Exception: when the model name matches no supported backend, or when
            every Cohere attempt timed out.
        """
        # For traced calls K is forced to 0 so that get_cache() yields no history.
        # try/finally guarantees K is restored even if building the messages fails.
        saved_k = self.K
        try:
            # Equality (not truthiness) is kept so non-bool values behave as before.
            if trace == True:
                self.K = 0
            self.cache_list = self.get_cache()
            messages = self.query_messages
            if trace == False:
                messages[-1]['content'] += " Based on the failure explanation and scene description, analyze and plan again."
        finally:
            self.K = saved_k

        if self.model in ['text-davinci-003']:
            # NOTE(review): `import openai` is commented out at the top of this file,
            # so this branch raises NameError until that import is restored.
            prompt = convert_messages_to_prompt(messages)
            response = openai.Completion.create(
                model=self.model,
                prompt=prompt,
                stop=stop,
                temperature=temperature,
                max_tokens = 256)

        elif "meta-llama" in self.model or "google" in self.model or "Qwen" in self.model:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

            # Only messages[1] is used as the prompt; a leading system message is
            # folded into it. The prompt is built as a local string so the stored
            # message dicts are not modified (otherwise the system text would be
            # prepended again on every call that reuses the same message).
            has_system = messages[0]['role'] == "system"
            prompt = messages[1]['content']
            if has_system:
                system_text = messages[0]['content']
                if "meta-llama" in self.model or "google" in self.model:
                    prompt = "<system> " + system_text + "</system>\n" + prompt
                else:
                    prompt = system_text + prompt

            input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.llm_model.device)
            input_length = input_ids["input_ids"].shape[1]

            # Repeat the encoded prompt so all candidates are sampled in one batch.
            num_queries = 20
            batched_input_ids = {name: tensor.repeat(num_queries, 1) for name, tensor in input_ids.items()}

            generate_kwargs = dict(
                max_new_tokens=256,
                do_sample=True,  # Enable sampling
                temperature=1,  # Adjust randomness (higher = more random)
                top_k=50,  # Sample from the top-k tokens
                top_p=0.9,  # Nucleus sampling (alternative to top-k)
                return_dict_in_generate=True,
                output_scores=True,
            )
            if stop:
                generate_kwargs["stopping_criteria"] = StoppingCriteriaList(
                    [StopOnWordCriteria(self.tokenizer, stop, input_length=input_length)]
                )
            outputs = self.llm_model.generate(**batched_input_ids, **generate_kwargs)

            # Drop the prompt tokens; keep only the newly generated part of each sample.
            generated_tokens = outputs.sequences[:, input_length:]

            transition_scores = self.llm_model.compute_transition_scores(
                                outputs.sequences, outputs.scores, normalize_logits=True)

            # Replace -inf log-probabilities by 0 so they do not zero out the sequence probability.
            transition_scores = torch.where(transition_scores == float('-inf'), 0.0, transition_scores)

            responses = [self.tokenizer.decode(generated_token, skip_special_tokens=True) for generated_token in generated_tokens]

            action_index = [generate_action_index(response, index, layout) for response in responses]

            # Sequence probability = exp(sum of per-token log-probabilities).
            llm_probs_list = [torch.exp(transition_score.sum(axis=0)).cpu().numpy() for transition_score in transition_scores]

            probs_list = [self.get_new_prob_w_op_index(a_index, index, op_index, llm_prob, tuning, gamma, layout) for a_index, llm_prob in zip(action_index, llm_probs_list)]

            try:
                final_responses, final_probs = remove_redundant_text(probs_list, responses, layout)
            except Exception as e:
                print(e)
                print("new prob list", probs_list)
                print("Responses:", responses)
                # Fall back to the action the causal graph alone prefers.
                final_responses, final_probs = self.get_action_based_on_causal_graph(index, op_index, layout)
                if final_responses is None:
                    return ""
                if failure_handled:
                    print("failure_handled")
                    best_resp = f"Plan for Player {agent_index}: {final_responses}."
                    print(f"Best response: {best_resp}")
                    return best_resp
                print("Fail: ", responses)
                return ""

            # Print the final response and corresponding probability
            for resp, prob in zip(final_responses, final_probs):
                print(f"Response: Plan for Player {agent_index}: {resp} | Probability: {prob}")
            best_resp = random.choices(final_responses, weights=final_probs, k=1)[0]

            # NOTE(review): agent_index 0 is falsy, so player 0 gets no
            # "Plan for Player" prefix here -- confirm whether that is intended.
            if agent_index:
                best_resp = f"Plan for Player {agent_index}: {best_resp}."
            print(f"Best response: {best_resp}")
            return best_resp

        elif 'command' in self.model:
            max_attempts = 5
            for attempt in range(max_attempts):
                try:
                    response = self.co.chat(
                        model=self.model,
                        messages=messages,
                        max_tokens=256,
                        temperature=temperature
                    )
                    break  # Exit loop if successful
                except httpx.ReadTimeout:
                    print(f"Timeout occurred. Retrying {attempt + 1}/{max_attempts}...")
                    time.sleep(5)
            else:
                raise Exception("API request failed after multiple retries.")

        elif 'gpt' in self.model:
            # NOTE(review): same missing `openai` import as the text-davinci branch.
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=messages,
                stop=stop,
                temperature=temperature,
                max_tokens = 256)
        else:
            raise Exception(f"Model {self.model} not supported.")

        return self.parse_response(response)


    def parse_response(self, response):
        """
        Extract the reply text from a backend response, based on the model name.

        returns:
        The text found at the location each response format uses, the response
        itself for 'claude', or None when the model name matches no known format.
        """
        model = self.model
        if model == 'claude':
            return response
        if model in ['text-davinci-003']:
            return response["choices"][0]["text"]
        if model in ['gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0301', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-0314']:
            return response["choices"][0]["message"]["content"]
        if any(tag in model for tag in ("meta-llama", "google", "Qwen")):
            return response[0]['generated_text']
        if "command" in model:
            return response.message.content[0].text
        return None

    def restrict_dialogue(self):
        """
        Trim the cached messages until the prompt fits the model's token limit.

        The limit comes from TOKEN_LIMIT_TABLE (e.g. 4096 for gpt-3.5-turbo-0301).
        While the prompt token length is at or above it, the four oldest cached
        messages are removed per round.
        NOTE(review): relies on `self.prompt_token_length`, which is not defined
        in the visible part of this class -- confirm where it is provided.
        """
        token_budget = TOKEN_LIMIT_TABLE[self.model]
        print(f'Current token: {self.prompt_token_length}')
        while self.prompt_token_length >= token_budget:
            # Drop the four oldest entries, one pop at a time.
            for _ in range(4):
                self.cache_list.pop(0)
            print(f'Update token: {self.prompt_token_length}')
        
    def reset(self):
        """Forget the dialog history by rebinding it to a fresh empty list."""
        self.dialog_history_list = list()

